% ---- Reset the command window, workspace and figures ----
clc;
clear('all');
close('all');

% ---- Array sizes ----
M1 = 8;   M2 = 8;     % per-dimension factors of the BS antenna count
N1 = 16;  N2 = 16;    % per-dimension factors of the RIS element count
M  = M1*M2;           % the number of antennas at the BS
N  = N1*N2;           % the number of RIS elements
K  = 16;              % the number of users

% ---- Path counts ----
L1 = 5;               % the number of paths between the BS and the RIS
L2 = 8;               % the number of paths between the RIS and the user

% Four candidate values for the number of common paths between the RIS and
% all users; each one is handed to DS_OMP in the simulation loop.
L2c1 = 0;
L2c2 = 4;
L2c3 = 6;
L2c4 = 8;

% ---- Pilot overhead sweep and array geometry ----
Q_all = 32:16:128;    % the number of time slots for the channel estimation
d     = 0.5;          % the antenna spacing lambda/2

% ---- Large-scale path loss: p = C0 * distance^alpha ----
C0     = 0.001;
d1     = 10;          % the distance between the BS and the RIS
d2     = 100;         % the distance between the RIS and the user
alpha1 = -2.2;        % exponent applied to d1
alpha2 = -2.8;        % exponent applied to d2
p1     = C0*d1^alpha1;
p2     = C0*d2^alpha2;


% ---- RIS-side dictionary ----
% UN1/UN2 are N1-by-N1 / N2-by-N2 matrices of unit-modulus complex
% exponentials scaled by 1/sqrt(size); UN is their Kronecker product (N-by-N).
UN1 = (1/sqrt(N1))*exp(-1i*2*pi*(0:N1-1)'*d*(0:N1-1)*(2/N1));
UN2 = (1/sqrt(N2))*exp(-1i*2*pi*(0:N2-1)'*d*(0:N2-1)*(2/N2));
UN  = kron(UN1,UN2);

% ---- BS-side dictionary ----
% Same construction, but the column grid runs over the half-integer offsets
% -(M-1)/2, ..., (M-1)/2 instead of 0, ..., M-1.
UM1 = (1/sqrt(M1))*exp(-1i*2*pi*(0:M1-1)'*d*(-(M1-1)/2:(M1/2))*(2/M1));
UM2 = (1/sqrt(M2))*exp(-1i*2*pi*(0:M2-1)'*d*(-(M2-1)/2:(M2/2))*(2/M2));
UM  = kron(UM1,UM2);


% ---- Operating point ----
SNR_dB     = 0;
SNR_linear = 10.^(SNR_dB/10);

% ---- Monte-Carlo bookkeeping ----
sample = 20;              % number of independent trials
len    = numel(Q_all);    % number of pilot-overhead values swept

% One squared-error accumulator per estimator: rows index the trial,
% columns index the entry of Q_all.
[error0,error1,error2,error3,error4,error5,error6] = deal(zeros(sample,len));

% Channel energy of each trial, used later to normalise the errors.
energy = zeros(sample,1);

% Monte-Carlo simulation: for every trial draw a fresh channel, and for every
% pilot overhead Q draw a fresh measurement, run all estimators and record
% their total squared error.

% The noise variance depends only on fixed quantities (path losses and SNR),
% so it is computed once instead of on every trial.
sigma2=(p1*p2)/(SNR_linear);

for s=1:sample
    s   % echo the trial index as a progress indicator

    % Draw the BS-RIS channel and the RIS-user channels, scaled by the
    % respective path losses.
    % NOTE(review): G_channel / hr_channel are defined elsewhere; from the
    % products below G is presumably M-by-N and hK N-by-K -- confirm.
    G=sqrt(p1)*G_channel(M1,M2,N1,N2,L1);
    hK=sqrt(p2)*hr_channel(N1,N2,K,L2,L2);

    % Transformed cascaded channel of every user, stored as an N-by-M slice.
    H=zeros(N,M,K);
    for k=1:K
        hr=hK(:,k);
        HC=G*diag(hr);
        % conj(UN) equals (UN.')'; the outer ' turns the result into N-by-M.
        H(:,:,k)=(UM'*HC*conj(UN))';
%         figure()
%         plot(abs(H(:,:,k)))
%         a=1;
    end

    for iQ=1:len
        Q=Q_all(iQ);

        % Random +/-1/sqrt(N) matrix W (N-by-Q) and the resulting Q-by-N
        % sensing matrix A shared by all users.
        Y=zeros(Q,M,K);
        W=((rand(N,Q)>0.5)*2-1)/sqrt(N);
        A=(UN*W)';
        for k=1:K
             % Circularly-symmetric complex Gaussian noise of variance sigma2.
             noise = sqrt(sigma2)*(randn(Q,M)+1i*randn(Q,M))/sqrt(2);
             Y(:,:,k)=A*H(:,:,k)+noise;
        end

        %% the Oracle LS based scheme (receives the true H)
        [Hhat0,row0,column0]=OracleOMP(Y,H,A,M,N,K,L1,L2);

        %% the conventional OMP based scheme
        Hhat1=OMP(Y,A,M,N,K,L1,L2);

        %% the row-structured sparsity based scheme
        [Hhat2,row2,column2]=SS_OMP(Y,A,M,N,K,L1,L2);

        %% the proposed DS-OMP based scheme, once per common-path setting
        [Hhat3,row3,column3]=DS_OMP(Y,A,M,N,K,L1,L2,L2c1);
        [Hhat4,row4,column4]=DS_OMP(Y,A,M,N,K,L1,L2,L2c2);
        [Hhat5,row5,column5]=DS_OMP(Y,A,M,N,K,L1,L2,L2c3);
        [Hhat6,row6,column6]=DS_OMP(Y,A,M,N,K,L1,L2,L2c4);

        % Total squared estimation error over all entries and all users.
        error0(s,iQ)=sum(sum(sum(abs(Hhat0-H).^2)));
        error1(s,iQ)=sum(sum(sum(abs(Hhat1-H).^2)));
        error2(s,iQ)=sum(sum(sum(abs(Hhat2-H).^2)));
        error3(s,iQ)=sum(sum(sum(abs(Hhat3-H).^2)));
        error4(s,iQ)=sum(sum(sum(abs(Hhat4-H).^2)));
        error5(s,iQ)=sum(sum(sum(abs(Hhat5-H).^2)));
        error6(s,iQ)=sum(sum(sum(abs(Hhat6-H).^2)));

    end

    % Total channel energy of this trial (normalisation for the NMSE).
    energy(s)=sum(sum(sum(abs(H).^2)));
end
 
% Normalised mean-squared error per pilot overhead: average the squared error
% over the trials and divide by the average channel energy.
% The averaging dimension is given explicitly (dim 1 = trials). Without it,
% a run with sample=1 makes each error matrix a row vector and mean() would
% average across the Q values instead, collapsing the curve to one number.
% Statements are left unsuppressed so the values are echoed as before.
nmse0=mean(error0,1)/mean(energy)
nmse1=mean(error1,1)/mean(energy)
nmse2=mean(error2,1)/mean(energy)
nmse3=mean(error3,1)/mean(energy)
nmse4=mean(error4,1)/mean(energy)
nmse5=mean(error5,1)/mean(energy)
nmse6=mean(error6,1)/mean(energy)



% Convert the linear NMSE values to decibels.
nmse0=10*log10(nmse0)
nmse1=10*log10(nmse1)
nmse2=10*log10(nmse2)
nmse3=10*log10(nmse3)
nmse4=10*log10(nmse4)
nmse5=10*log10(nmse5)
nmse6=10*log10(nmse6)


% Plot NMSE (dB) against the pilot overhead for all estimators.
figure('color',[1,1,1]);
ha=gca;
plot(Q_all,nmse1,'m>-',...
    Q_all,nmse2,'b^-',...
    Q_all,nmse3,'r<-',...
    Q_all,nmse4,'ro-',...
    Q_all,nmse5,'rs-',...
    Q_all,nmse6,'rd-',...
    Q_all,nmse0,'k--','linewidth',1.5);
% Legend entries follow the order of the curves in the plot call above.
legend('Conventional CS based scheme [10]',...
    'Row-structured sparsity based scheme [11]',...
    'Proposed DS-OMP based scheme (Lc=0)',...
    'Proposed DS-OMP based scheme (Lc=4)',...
    'Proposed DS-OMP based scheme (Lc=6)',...
    'Proposed DS-OMP based scheme (Lc=8)',...
    'Oracle LS based scheme')
grid on
xlabel('The pilot overhead Q for the cascaded channel estimation')
ylabel('NMSE (dB)')
% Take the x ticks and x limits from Q_all so the axis always matches the
% simulated pilot overheads (identical to 32:16:128 for the default sweep).
set(ha,'xtick',Q_all)
set(ha,'ytick',[-25:5:-5])
axis([min(Q_all),max(Q_all),-25,-5])
